"""The purpose of this module is to provide a server layer which passes all
outputs available on the current node to any requesting node. The requests
are received in the form of broadcasts and the file transfer takes place
through TCP sockets

"""

import socketserver
import threading
import queue
import os.path
import re
from . import constants
import ma.net.tcpsv
import ma.const
import ma.log
import ma.net.brdcstsv
from . import filetransferinfo
from ma.net import tcpsv
from threading import Lock


# --- packet vocabulary ------------------------------------------------------
# separator placed between file ids when several are listed in one packet
OUTPUT_ID_SEP = ','
# packet layout: <tag> "<file id list>"
PCKT_MSG = '%s "%s"'
# tag of a broadcast query asking which of the listed outputs a node holds
OUTPUT_AVAILABILITY_TAG = 'OA'
# tag of the reply that lists the outputs available on the replying node
OUTPUT_IS_AVAILABLE_TAG = 'OI' 
# tag of a TCP request asking for the contents of one output file
REQUEST_FILE_TAG = 'RQ'

# buffer sizes (looked up through ma.const.XmlData)
# max. number of bytes read from a TCP client for its request packet
command_buff_size = ma.const.XmlData.get_int_data(ma.const.xml_net_command_buff_size)
# number of bytes read from an output file per chunk while it is being sent
file_transfer_buff_size = ma.const.XmlData.get_int_data(ma.const.xml_net_file_transfer_buff_size)

# stats transfer bytes: running total of file bytes sent by the handlers,
# updated while holding the stats lock shared through the TCP server
stats_bytes_sent = 0

# the first no. is job id, 2. Reduce or Map ('R' / 'M'), 3. the task id,
# 4. destination task id
FILE_ID_REGEXP = r'(\d+)([RM])(\d+)_(\d+)'
# a complete packet: a one or two letter tag, whitespace, then a quoted list
# of file ids, each optionally followed by a comma
CLIENT_PCKT_REGEXP = r'([a-zA-Z]{1,2})\s+"((' + FILE_ID_REGEXP + r',?)+)"'

# file transfer id for the info object; it is handed to the next
# FileTransferInfo and then advanced with filetransferinfo.get_next_id()
file_transfer_id = constants.FILE_TRANSFER_ID_DEFAULT



class SendOutputHandler(socketserver.BaseRequestHandler):
    """Socket Handler which would be passed to the Threaded TCP server.

    A client sends one request packet of the form ``RQ "<file id>"``. When
    the node tracker knows a path for the requested task output, the file
    contents are streamed back over the same connection.

    The handler expects ``server.args_list`` to hold, in this order: the node
    tracker reference, the lock guarding the transfer statistics and the
    queue collecting ``FileTransferInfo`` objects.
    """

    def __init__(self, request, client_address, server):
        """Pick the shared objects off the server and serve the client.

        Any error raised while setting up is logged, not propagated.
        """
        # initialize __log for Send Output Handler
        self.__log = ma.log.get_logger("ma.commons")

        try:
            self.__log.debug('Send Output Handler starting to serve a client')

            # get the nodetracker reference
            self.__internal_nt_ref = server.args_list[0]

            # stats lock reference
            self.__stats_lock = server.args_list[1]

            # file stats info queue
            self.__file_stats_queue = server.args_list[2]

            # the base constructor drives setup(), handle() and finish()
            socketserver.BaseRequestHandler.__init__(self, request, client_address, server)
        except Exception as err_msg:
            self.__log.error("Error while creating Send Output Handler: %s", err_msg)

    def handle(self):
        """This function is derived from BaseRequestHandler to deal with every
        connecting client.

        Reads one request packet, validates it against CLIENT_PCKT_REGEXP and,
        for a file request whose output is present, sends the file and queues
        a FileTransferInfo describing the transfer. Errors are logged and
        never propagate to the server.
        """

        global file_transfer_id

        try:
            last_msg_ip = self.client_address[0]
            last_msg_pckt = self.request.recv(command_buff_size)

            # recv() returns bytes whereas the packet pattern is a str
            # pattern; decode first, otherwise re.findall raises TypeError.
            # Undecodable bytes are replaced so that a corrupt packet is
            # reported through the warning below.
            if isinstance(last_msg_pckt, bytes):
                last_msg_pckt = last_msg_pckt.decode('utf-8', 'replace')

            # split packet into the packet tag and the resource request part
            pckt_split = re.findall(CLIENT_PCKT_REGEXP, last_msg_pckt)

            # exactly one match carrying the file request tag is a valid packet
            if len(pckt_split) != 1 or pckt_split[0][0] != REQUEST_FILE_TAG:
                self.__log.warning("Incorrect or corrupt packet received from %s: %s", last_msg_ip, last_msg_pckt)
                return

            # groups 3..6 of the match are the parts of the (last) file id
            job_id, map_or_red, task_id, dest_task_id = pckt_split[0][3:7]
            job_id = int(job_id)
            task_id = int(task_id)
            dest_task_id = int(dest_task_id)

            filepath = self.__internal_nt_ref.getCompleteTaskOutputPath(job_id, map_or_red, task_id, dest_task_id)

            # check if the input is available
            if filepath is None:
                self.__log.error('The file for %s task %s requested by client is not available at this node Job-id %s input for task %s', str(map_or_red), str(task_id), str(job_id), str(dest_task_id))
                return

            # get file size
            size = os.path.getsize(filepath)

            # create file stats object; the id counter is shared between all
            # handler threads, hence the lock
            with self.__stats_lock:
                transfer_info = filetransferinfo.FileTransferInfo(file_transfer_id, filepath, size, job_id)
                # adjust the id for next transfer
                file_transfer_id = filetransferinfo.get_next_id(file_transfer_id)

            # transfer the file to the requesting client
            self.__log.info('starting file transfer to %s: %s', last_msg_ip, filepath)
            transfer_info.note_transfer_start()
            transferred = self.__transfer_file(filepath)
            transfer_info.note_transfer_end()

            if transferred:
                self.__log.info('file transfer to %s complete: %s', last_msg_ip, filepath)
            else:
                self.__log.error('file transfer to %s failed: %s', last_msg_ip, filepath)

            # insert file transfer stats into a common queue (done for every
            # attempted transfer, as before)
            self.__file_stats_queue.put(transfer_info)
        except Exception as err_msg:
            self.__log.error("Error while serving client: %s", err_msg)

    def __transfer_file(self, filepath):
        """This function is called by the handle function to transfer files.

        The file is read in chunks of ``file_transfer_buff_size`` bytes and
        every chunk is written completely to the client socket. The global
        byte counter is increased after each chunk.

        Returns True when the whole file was sent, False when an error
        occurred (the error is logged).
        """

        # for editing the global stats variable
        global stats_bytes_sent

        try:
            # get file size
            sz = os.path.getsize(filepath)

            # binary mode: sockets carry bytes and outputs need not be text;
            # the with-block closes the file even when sending fails
            with open(filepath, mode='rb') as fd:
                # read the initial chunk
                data = fd.read(file_transfer_buff_size)

                if not data:
                    self.__log.error("File found empty during read: %s, OS returned file size %d", filepath, sz)

                while data:
                    # sendall keeps writing until the whole chunk is out (or
                    # raises), so a short write can no longer drop bytes
                    self.request.sendall(data)

                    # update number of bytes sent
                    with self.__stats_lock:
                        stats_bytes_sent += len(data)

                    data = fd.read(file_transfer_buff_size)

            return True
        except Exception as err_msg:
            self.__log.error("Error while transferring file to client: %s", err_msg)
            return False

    def finish(self):
        """Log the end of the request and defer to the base class."""
        self.__log.debug('Send Output Handler: finished serving client')
        return socketserver.BaseRequestHandler.finish(self)


class OutputServer(object):
    """Output server thread will instantiate a threaded TCP Server and a UDP
    Broadcast Listen Server, whose task would be to listen client requests
    from other nodes, asking for outputs to specific tasks. This server also
    deals with File Transfers for requested output files
    """

    def __init__(self, nodetracker_ref):
        """Constructor for creating and running Threaded TCP Server. It passes
        the SendOuputHandler as a handler to Threaded TCP Server's client
        requests.

        ``nodetracker_ref`` must offer ``getCompleteTaskOutputPath(job_id,
        map_or_red, task_id, dest_task_id)``. Errors raised during start-up
        are logged, not propagated.
        """

        # initialize __log for Output Server
        self.__log = ma.log.get_logger("ma.commons")

        try:
            self.__internal_nt_ref = nodetracker_ref

            # this dictionary is list of ips against every task identifier (key)
            # it used to maintain a list of nodes having the said output files
            self.__output_src_dict = {}

            # lock used to update the number of bytes sent
            self.__stats_lock = Lock()
            self.__file_stats_queue = queue.Queue()
            self.file_transfer_id = constants.FILE_TRANSFER_ID_DEFAULT

            # get the binding address
            binding_addr = str.strip(ma.const.XmlData.get_str_data(ma.const.xml_tcp_out_srv_bind_addr))
            self.__output_comm_port = ma.const.XmlData.get_int_data(ma.const.xml_net_output_srv_bind_port)
            # NOTE: read but not used anywhere in this class
            broadcast_addr = str.strip(ma.const.XmlData.get_str_data(ma.const.xml_broadcast_address))

            # instantiate the TCP server thread
            self.__server = tcpsv.ThreadedTCPServer(binding_addr, self.__output_comm_port, SendOutputHandler, [self.__internal_nt_ref, self.__stats_lock, self.__file_stats_queue])

            # create thread for serving clients forever; it is not a daemon
            # thread, so it keeps the interpreter alive while it runs
            self.__srv_th = threading.Thread(target=self.__server.serve_forever)

            # start the Threaded TCP Server
            self.__srv_th.start()

            self.__log.debug('started TCP Threaded Server')

            # create the sender before any query can be processed: the
            # callback thread replies through it and must never find the
            # attribute missing
            self.__udp_pckt_sender = ma.net.brdcstsv.BroadcastPacketSender()
            self.__log.debug('started UDP Packet Sender')

            # instantiate the UDP Broadcast listen server
            # Note that we are going to bind with the same address ... since this is UDP side
            self.__brdcst_listen_srvr = ma.net.brdcstsv.BroadcastListenServer(binding_addr, self.__output_comm_port, None, True, True)
            callback_thread = threading.Thread(target=self.__query_callback_thread)
            callback_thread.start()
            self.__brdcst_listen_srvr.start()
            self.__log.debug('started UDP Broadcast Listen Server')

            self.__log.info('completed init of Output Server')
        except Exception as err_msg:
            self.__log.error("Error while creating TCP Output Server: %s", err_msg)

    def __query_callback_thread(self):
        """This function runs an infinite loop which processes the queue
        'self.__brdcst_listen_srvr.msg_queue' holding the requests received
        by the broadcast listener, one message per iteration.
        """

        while True:
            # call callback function to interpret the incoming UDP packet
            self.__broadcast_query_callback()

    @staticmethod
    def __parse_file_ids(pckt_files):
        """Split a packet's file list into (file id text, identifier) pairs.

        The identifier is the tuple (job id, 'M' or 'R', task id, destination
        task id) with the numbers converted to int. Entries that do not
        contain a file id are skipped; the packet pattern permits a trailing
        separator, which yields such an empty entry.
        """

        parsed = []
        for fl in pckt_files.split(OUTPUT_ID_SEP):
            file_desc = re.findall(FILE_ID_REGEXP, fl)
            if not file_desc:
                continue
            job_id, map_or_red, task_id, dest_task_id = file_desc[0]
            parsed.append((fl, (int(job_id), map_or_red, int(task_id), int(dest_task_id))))
        return parsed

    def __answer_availability_query(self, last_msg_ip, pckt_files):
        """Reply to an 'OA' query with the requested outputs held by this node."""

        requested = self.__parse_file_ids(pckt_files)
        found = []

        # iterate through all the resources needed
        for fl, (job_id, map_or_red, task_id, dest_task_id) in requested:
            # check if the input is available
            if self.__internal_nt_ref.getCompleteTaskOutputPath(job_id, map_or_red, task_id, dest_task_id) is not None:
                self.__log.info("Found file %s, for ip %s", fl, str(last_msg_ip))
                found.append(fl)
            else:
                self.__log.info("Could not find file %s", fl)

        available_outputs = OUTPUT_ID_SEP.join(found)
        self.__log.info("Available Outputs %s , %s", available_outputs, last_msg_ip)

        # if any output is available, reply
        if available_outputs != '':
            reply_msg = PCKT_MSG % (OUTPUT_IS_AVAILABLE_TAG, available_outputs)
            self.__log.info("Sending Response %s %s", reply_msg, last_msg_ip)
            self.__udp_pckt_sender.send_message(reply_msg, (last_msg_ip, self.__output_comm_port))

        self.__log.info("Received an Output Availability query packet for %s from %s", [fl for fl, _ in requested], str(last_msg_ip))

    def __record_availability_response(self, last_msg_ip, pckt_files):
        """Remember that ``last_msg_ip`` holds every output listed in an 'OI' reply."""

        announced = self.__parse_file_ids(pckt_files)

        # iterate through all the resources returned
        for _, task_identifier in announced:
            hosts = self.__output_src_dict.setdefault(task_identifier, [])
            # every host is listed once per output
            if last_msg_ip not in hosts:
                hosts.append(last_msg_ip)

        self.__log.info("Received an Output Availability response packet for %s from %s", [fl for fl, _ in announced], str(last_msg_ip))
        self.__log.debug("Current state of resources dictionary: %s", str(self.__output_src_dict))

    def __broadcast_query_callback(self):
        """This will receive broadcast requests querying the presence of
        outputs at this node. If output present a response is sent telling
        which files are available. Responses of other nodes ('OI' packets)
        are recorded in the output source dictionary.

        Blocks until a message is available in the listener's queue. Errors
        are logged and never propagate to the calling loop.
        """

        try:
            # get the last message that was received through UDP
            last_msg = self.__brdcst_listen_srvr.msg_queue.get()
            last_msg_ip = last_msg[0]
            last_msg_pckt = last_msg[1]

            # NOTE(review): the listener may hand over raw bytes; the packet
            # pattern is a str pattern, so decode in that case - confirm
            # against BroadcastListenServer
            if isinstance(last_msg_pckt, bytes):
                last_msg_pckt = last_msg_pckt.decode('utf-8', 'replace')

            # split packet into the packet tag and the resource request part
            pckt_split = re.findall(CLIENT_PCKT_REGEXP, last_msg_pckt)

            self.__log.debug('Broadcast packet received: %s || splits: %s', last_msg, str(pckt_split))

            pckt_correctly_recv = False

            if len(pckt_split) == 1:
                # get the valid info from the packet
                pckt_tag = pckt_split[0][0]
                pckt_needed_files = pckt_split[0][1]

                if pckt_needed_files != '':
                    if pckt_tag == OUTPUT_AVAILABILITY_TAG:
                        # Packet received for output availability request
                        self.__answer_availability_query(last_msg_ip, pckt_needed_files)
                        pckt_correctly_recv = True
                    elif pckt_tag == OUTPUT_IS_AVAILABLE_TAG:
                        # Packet received for output availability response
                        self.__record_availability_response(last_msg_ip, pckt_needed_files)
                        pckt_correctly_recv = True

            if not pckt_correctly_recv:
                self.__log.warning("Incorrect or corrupt packet received from %s: %s", last_msg_ip, last_msg_pckt)

        except Exception as err_msg:
            self.__log.error("Error while handling user broadcast query: %s", err_msg)

    def return_output_host_ip(self, job_id, map_or_red, task_id, dest_task_id):
        """This returns the first IP in the list of available hosts. Else
        returns None if no host is available
        """

        hosts = self.__output_src_dict.get((job_id, map_or_red, task_id, dest_task_id))
        return hosts[0] if hosts else None

    def return_bytes_transferred(self):
        """This function returns the cumulative number of bytes transferred
        from this server since startup. Essentially it will give the total
        bytes shuffled from this node
        """

        return stats_bytes_sent

    def flush_file_transfer_infos(self, job_id=None):
        """Returns all the FileTransferInfo items in the internal queue. It
        also empties the queue. If the job_id is given it will only give the
        FileTransferInfo objs for that particular job; the items of other
        jobs are put back into the queue.
        """

        return_list = []

        # look at no more items than were queued on entry, so that entries
        # which are pushed back are not examined a second time
        for _ in range(self.__file_stats_queue.qsize()):
            try:
                # never block: the queue may have been drained in the meantime
                filetxinfo = self.__file_stats_queue.get_nowait()
            except queue.Empty:
                break

            if job_id is not None and job_id != filetxinfo.job_id:
                # push back in queue again if the job_id doesnt match
                self.__file_stats_queue.put(filetxinfo)
            else:
                # if job_id == None then push all, or job_id matches
                return_list.append(filetxinfo)

        return return_list

    def __del__(self):
        """Destructor to the Threaded TCP Server.

        Only logs the shutdown; the server threads are not stopped here.
        """

        try:
            self.__log.info('Shutting down TCP Output Server')
        except Exception as err_msg:
            self.__log.error("Error while trying to destruct the OutputServerThread: %s", err_msg)



class DummyNodeTracker(object):
    """Minimal node tracker stand-in for testing outputs and file transfer"""

    def getCompleteTaskOutputPath(self, job_id, map_or_red, task_id, dest_task_id):
        """Return the path of the task's output file, or None if it is absent.

        The path is built from the job's local temporary output directory and
        the configured map or reduce output file name. ``dest_task_id`` is
        accepted but not used.
        """
        import os.path
        import os

        base_dir = ma.const.JobsXmlData.get_filepath_str_data(ma.const.xml_local_output_temp_dir, job_id)

        # pick the file-name setting that matches the kind of task
        if constants.MAP == map_or_red:
            name_setting = ma.const.xml_map_output_filename
        else:
            name_setting = ma.const.xml_reduce_output_filename

        candidate = base_dir + os.sep + ma.const.JobsXmlData.get_str_data(name_setting, job_id, task_id)

        return candidate if os.path.exists(candidate) else None
        


if __name__ == '__main__':
    # manual smoke test: bring the output server up against the dummy tracker
    outsvr = OutputServer(nodetracker_ref=DummyNodeTracker())
